import torch
import torch.distributions
from utils.datasets.paths import get_svhn_path
from utils.datasets.svhn import get_SVHN_labels
from utils.datasets.svhn_augmentation import get_SVHN_augmentation
from utils.datasets import TINY_LENGTH
from torch.utils.data import Dataset
from torch.utils.data import Sampler
from torchvision import datasets, transforms
import numpy as np

from .svhn_validation_extra_split import SVHNValidationExtraSplit
from .cifar_semi_tiny_partition import BalancedSampler
from .loading_utils import load_teacher_data
from utils.datasets.paths import get_tiny_images_files
from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images

def get_svhn_tiny_partition(dataset_classifications_path, teacher_model, samples_per_class, svhn_extra_val_split,
                            class_tpr_min=None, od_exclusion_threshold=None, calibrate_temperature=False,
                            id_class_balanced=True, verbose_exclude=False, soft_labels=True, batch_size=128,
                            all_sampler=False, augm_type='default', subdivide_epochs=False,
                            num_workers=8,
                            id_config_dict=None, od_config_dict=None, ssl_config=None):
    """Create the top-k and bottom-k data loaders of the SVHN / Tiny Images partition.

    The teacher outputs are loaded with ``load_teacher_data``. The top-k dataset
    combines the SVHN train split with the most confident candidates per class;
    the bottom-k dataset contains the candidates that are left over.

    Args:
        dataset_classifications_path, teacher_model, class_tpr_min,
        od_exclusion_threshold, calibrate_temperature, ssl_config:
            forwarded unchanged to ``load_teacher_data``.
        samples_per_class: per-class cap of selected candidates. With
            ``all_sampler`` it is instead handed to ``AllValidSampler``.
        svhn_extra_val_split: forwarded to both datasets to choose which SVHN
            extra split they read.
        id_class_balanced: only recorded in ``id_config_dict``.
        verbose_exclude: if true, every candidate passing the confidence
            threshold is removed from the bottom-k dataset, not only the
            selected ones. Forced to ``True`` when ``all_sampler`` is set.
        soft_labels: forwarded to both datasets.
        batch_size: batch size of both loaders.
        all_sampler: use ``AllValidSampler`` with a cap of ``1e8`` candidates
            per class instead of ``BalancedSampler``.
        augm_type: augmentation name passed to ``get_SVHN_augmentation``.
        subdivide_epochs: forwarded to ``BalancedSampler``.
        num_workers: worker count of the top loader (the bottom loader uses 1).
        id_config_dict, od_config_dict: optional mappings that are filled with
            a description of the created loaders.

    Returns:
        Tuple ``(top_loader, bottom_loader)``.
    """
    teacher_logits, _, thresholds, temp = load_teacher_data(
        dataset_classifications_path, teacher_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=od_exclusion_threshold,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config)

    # get_SVHN_augmentation receives this dict and may record its settings in it
    augm_config = {}
    base_transform = get_SVHN_augmentation(augm_type, config_dict=augm_config)

    if all_sampler:
        verbose_exclude = True
    per_class_cap = 1e8 if all_sampler else samples_per_class

    top_dataset = SVHNPlusTinyImageSVHNExtraTopKPartition(
        teacher_logits, samples_per_class=per_class_cap,
        transform_base=base_transform, min_conf=thresholds,
        temperature=temp,
        svhn_extra_val_split=svhn_extra_val_split,
        soft_labels=soft_labels)

    if not all_sampler:
        top_sampler = BalancedSampler(top_dataset, subdivide_epochs)
    else:
        top_sampler = AllValidSampler(top_dataset, samples_per_class)

    top_loader = torch.utils.data.DataLoader(
        top_dataset, sampler=top_sampler, batch_size=batch_size, num_workers=num_workers)

    # everything the top partition claims is removed from the bottom partition
    excluded_indices = top_dataset.get_used_semi_indices(verbose_exclude)
    bottom_dataset = SVHNTinyImageBottomKPartition(
        teacher_logits, excluded_indices, transform_base=base_transform,
        temperature=temp, soft_labels=soft_labels,
        svhn_extra_val_split=svhn_extra_val_split)

    bottom_loader = torch.utils.data.DataLoader(
        bottom_dataset, shuffle=True, batch_size=batch_size, num_workers=1)

    if id_config_dict is not None:
        id_entries = (
            ('Dataset', 'SVHN-SSL'),
            ('Extra validation split', svhn_extra_val_split),
            ('Batch out_size', batch_size),
            ('Samples per class', samples_per_class),
            ('All Sampler', all_sampler),
            ('Soft labels', soft_labels),
            ('Class balanced', id_class_balanced),
            ('Augmentation', augm_config),
        )
        for key, value in id_entries:
            id_config_dict[key] = value

    if od_config_dict is not None:
        od_entries = (
            ('Dataset', 'TinyImagesPartition'),
            ('Batch out_size', batch_size),
            ('Verbose exclude', verbose_exclude),
            ('Augmentation', augm_config),
        )
        for key, value in od_entries:
            od_config_dict[key] = value

    return top_loader, bottom_loader


class SVHNPlusTinyImageSVHNExtraTopKPartition(Dataset):
    """SVHN train split plus the most confident teacher-labelled candidates.

    The candidate pool is indexed like the rows of ``model_logits``: the first
    ``TINY_LENGTH`` rows belong to the tiny images file, the remaining rows to
    the SVHN extra split. For each class, the candidates whose predicted class
    matches and whose maximum softmax confidence reaches ``min_conf[class]``
    are ranked by confidence and the best ``samples_per_class`` are kept.

    Dataset indices are laid out as
    ``[train class 0 | ... | train class C-1 | semi class 0 | ... | semi class C-1]``.
    ``train_idx_ranges`` and ``semi_idx_ranges`` hold the per-class
    ``(start, end)`` bounds of that layout for use by samplers.
    """

    def __init__(self, model_logits, samples_per_class, transform_base, min_conf,
                 svhn_extra_val_split=True,
                 temperature=1.0, soft_labels=True, preload=True):
        """Select the top-k candidates per class and optionally preload them.

        Args:
            model_logits: 2-D tensor of teacher logits, one row per candidate.
            samples_per_class: maximum number of candidates kept per class.
            transform_base: transform applied to every returned image.
            min_conf: per-class minimum softmax confidence, indexable by class.
            svhn_extra_val_split: read ``SVHNValidationExtraSplit`` instead of
                the plain torchvision SVHN ``extra`` split.
            temperature: divisor applied to the logits of soft labels.
            soft_labels: return probability vectors instead of class indices.
            preload: load the selected tiny images into memory up front.
        """
        self.data_location = get_tiny_images_files(False)
        # The handle stays open: without preloading, tiny images are read
        # lazily from it in __getitem__.
        self.fileID = open(self.data_location, "rb")
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.preload = preload
        self.temperature = temperature
        # Tiny images pass through ToPILImage before transform_base; the SVHN
        # datasets below receive transform_base directly.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        svhn_path = get_svhn_path()
        if svhn_extra_val_split:
            self.svhn_extra = SVHNValidationExtraSplit(svhn_path, split='extra-split', transform=transform_base)
        else:
            self.svhn_extra = datasets.SVHN(svhn_path, split='extra', transform=transform_base)

        self.model_logits = model_logits
        predicted_max_conf, predicted_class = torch.max(torch.softmax(self.model_logits, dim=1), dim=1)

        # model_logits contains all rows for 80M Tiny Images first, then SVHN extra
        self.is_tinyImage = torch.cat([torch.ones(TINY_LENGTH, dtype=torch.bool),
                                       torch.zeros(len(self.svhn_extra), dtype=torch.bool)])

        assert len(self.is_tinyImage) == model_logits.shape[0]

        class_labels = get_SVHN_labels()
        self.num_classes = len(class_labels)
        self.train_dataset = datasets.SVHN(svhn_path, split='train', transform=transform_base)

        self.num_train_samples = len(self.train_dataset)
        targets_tensor = torch.LongTensor(self.train_dataset.labels)
        self.train_class_idcs = self._indices_per_class(targets_tensor, self.num_classes)
        self.train_per_class = torch.zeros(self.num_classes, dtype=torch.long)
        for i, train_i in enumerate(self.train_class_idcs):
            self.train_per_class[i] = len(train_i)

        self.in_use_indices = []
        self.valid_indices = []
        self.semi_per_class = torch.zeros(self.num_classes, dtype=torch.long)

        min_samples_per_class = int(1e13)
        max_samples_per_class = 0

        for i in range(self.num_classes):
            valid_i, class_i_idcs = self._select_class_samples(
                predicted_max_conf, predicted_class, i, min_conf[i], samples_per_class)

            self.valid_indices.append(valid_i)
            self.in_use_indices.append(class_i_idcs)
            self.semi_per_class[i] = len(class_i_idcs)

            min_samples_per_class = min(min_samples_per_class, len(class_i_idcs))
            max_samples_per_class = max(max_samples_per_class, len(class_i_idcs))

            if len(class_i_idcs) < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {len(class_i_idcs)}')

        # Plain ints so that __len__ returns an int instead of a 0-d tensor.
        self.num_semi_samples = int(self.semi_per_class.sum())
        self.length = self.num_train_samples + self.num_semi_samples

        # internal idx ranges
        self.train_idx_ranges = []
        self.semi_idx_ranges = []

        train_idx_start = 0
        semi_idx_start = self.num_train_samples
        for i in range(self.num_classes):
            train_idx_next = train_idx_start + self.train_per_class[i]
            semi_idx_next = semi_idx_start + self.semi_per_class[i]
            self.train_idx_ranges.append((train_idx_start, train_idx_next))
            self.semi_idx_ranges.append((semi_idx_start, semi_idx_next))

            train_idx_start = train_idx_next
            semi_idx_start = semi_idx_next

        self.cum_train_lengths = torch.cumsum(self.train_per_class, dim=0)
        self.cum_semi_lengths = torch.cumsum(self.semi_per_class, dim=0)

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels}'
              f'  -  Target Samples per class { self.samples_per_class} - Train Samples {self.num_train_samples}')
        print(f'Min Semi Samples {min_samples_per_class} - Max Semi samples {max_samples_per_class}'
              f' - Total semi samples {self.num_semi_samples} - Total length {self.length}')

        if preload:
            print(f'Preloading images')
            self.class_data = []
            for class_idx in range(self.num_classes):
                cls_in_use = self.in_use_indices[class_idx]
                cls_is_tiny = self.is_tinyImage[cls_in_use]
                cls_in_use_80m = cls_in_use[cls_is_tiny]

                # Rows belonging to SVHN extra samples stay zero; those images
                # are read from self.svhn_extra in _load_tiny_svhn_extra_image.
                cls_imgs = np.zeros((len(cls_in_use), 32, 32, 3), dtype='uint8')
                cls_imgs_80m = _preload_tiny_images(cls_in_use_80m, self.fileID)
                cls_imgs[cls_is_tiny.numpy()] = cls_imgs_80m
                self.class_data.append(cls_imgs)

    @staticmethod
    def _indices_per_class(labels, num_classes):
        """Return one 1-D index tensor per class with the positions of that label."""
        # squeeze(1) rather than squeeze(): a class with exactly one sample
        # must stay 1-D so that len() and indexing keep working.
        return [torch.nonzero(labels == c, as_tuple=False).squeeze(1)
                for c in range(num_classes)]

    @staticmethod
    def _select_class_samples(predicted_max_conf, predicted_class, class_idx, threshold, samples_per_class):
        """Return ``(valid, top_k)`` candidate indices for one class.

        ``valid`` holds every candidate predicted as ``class_idx`` with a
        confidence of at least ``threshold``; ``top_k`` holds the at most
        ``samples_per_class`` most confident of them, best first.
        """
        mask = (predicted_class == class_idx) & (predicted_max_conf >= threshold)
        # squeeze(1) keeps a single match 1-D (squeeze() would make it 0-d).
        valid = torch.nonzero(mask, as_tuple=False).squeeze(1)
        order = torch.argsort(predicted_max_conf[mask], descending=True)
        num_samples = int(min(samples_per_class, len(valid)))
        return valid, valid[order[:num_samples]]

    @staticmethod
    def _locate(cum_lengths, index):
        """Map a flat index to ``(class_idx, index within class)`` as Python ints.

        Raises IndexError when ``index`` is not below the last cumulative length.
        """
        class_idx = int(torch.nonzero(cum_lengths > index, as_tuple=False)[0])
        offset = int(cum_lengths[class_idx - 1]) if class_idx > 0 else 0
        return class_idx, index - offset

    def get_used_semi_indices(self, verbose_exclude=False):
        """Return the candidate indices claimed by this partition.

        With ``verbose_exclude`` all indices that fulfil the confidence
        requirement are returned, including those outside of the top-k range;
        otherwise only the selected top-k indices.
        """
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        return torch.cat(self.in_use_indices)

    def _load_train_image(self, cifar_idx):
        """Return ``(img, label)`` of a train sample; one-hot label if soft labels are on."""
        img, label = self.train_dataset[cifar_idx]
        if not self.soft_labels:
            return img, label
        one_hot_label = torch.zeros(self.num_classes)
        one_hot_label[label] = 1.0
        return img, one_hot_label

    def _load_tiny_svhn_extra_image(self, class_idx, tiny_lin_idx):
        """Return ``(img, label)`` of the ``tiny_lin_idx``-th selected candidate of a class.

        The label comes from the teacher logits: a temperature-scaled softmax
        with soft labels, otherwise the argmax class index.
        """
        valid_index = self.in_use_indices[class_idx][tiny_lin_idx].item()

        if self.is_tinyImage[valid_index]:
            if self.preload:
                img = self.class_data[class_idx][tiny_lin_idx, :]
            else:
                img = _load_tiny_image(valid_index, self.fileID)

            if self.transform is not None:
                img = self.transform(img)
        else:
            # SVHN extra rows follow the tiny image rows in the candidate pool
            img, _ = self.svhn_extra[valid_index - TINY_LENGTH]

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    def __getitem__(self, index):
        """Return ``(img, label)`` for a flat index (train samples first, then candidates)."""
        # Samplers may yield 0-d tensors; work with plain ints from here on.
        index = int(index)
        if index < self.num_train_samples:
            class_idx, sample_idx = self._locate(self.cum_train_lengths, index)
            return self._load_train_image(int(self.train_class_idcs[class_idx][sample_idx]))

        class_idx, sample_idx = self._locate(self.cum_semi_lengths, index - self.num_train_samples)
        return self._load_tiny_svhn_extra_image(class_idx, sample_idx)

    def __len__(self):
        return self.length

class AllValidSampler(Sampler):
    """Class-balanced sampler over a top-k partition dataset.

    Every epoch yields the same number of indices for each class. All train
    indices of a class are always included, the remainder is filled with a
    random subset of the class' semi-supervised indices and, if the class is
    still short, with additional random draws from all of its indices.
    """

    def __init__(self, svhn_top_k_partition, semi_samples_per_class):
        """
        Args:
            svhn_top_k_partition: object providing ``semi_per_class`` and
                ``train_per_class`` (1-D integer tensors) as well as
                ``train_idx_ranges`` and ``semi_idx_ranges`` (per-class
                ``(start, end)`` pairs).
            semi_samples_per_class: number of semi-supervised samples wanted
                per class; the per-class epoch size is the largest train
                count plus this value.

        Raises:
            ValueError: if a class has neither train nor semi samples although
                samples have to be drawn for it (iteration could never finish).
        """
        # The base constructor's signature differs between torch releases
        # (older ones require the data_source argument).
        try:
            super().__init__(None)
        except TypeError:
            super().__init__()
        self.semi_per_class = svhn_top_k_partition.semi_per_class
        self.train_per_class = svhn_top_k_partition.train_per_class
        self.train_idx_ranges = svhn_top_k_partition.train_idx_ranges
        self.semi_idx_ranges = svhn_top_k_partition.semi_idx_ranges

        self.total_per_class = self.semi_per_class + self.train_per_class
        # Plain ints so that sizes, slices and __len__ do not depend on 0-d tensors.
        self.samples_per_class = int(torch.max(self.train_per_class + semi_samples_per_class))
        min_per_class = torch.min(self.total_per_class)
        max_per_class = torch.max(self.total_per_class)
        self.num_classes = len(self.semi_per_class)
        self.length = self.num_classes * self.samples_per_class

        print(f'All Valid Sampler: Samples/Class {self.samples_per_class} '
              f'- Max {max_per_class} - Min {min_per_class} - Length {self.length}')

        if self.samples_per_class > 0 and int(min_per_class) == 0:
            raise ValueError('AllValidSampler: at least one class has neither train nor semi samples')

    def __iter__(self):
        """Return an iterator over one shuffled, class-balanced epoch of dataset indices."""
        per_class_idcs = []
        for i in range(self.num_classes):
            train_start, train_end = self.train_idx_ranges[i]
            semi_start, semi_end = self.semi_idx_ranges[i]
            train_n = int(self.train_per_class[i])
            semi_n = int(self.semi_per_class[i])
            total_n = train_n + semi_n

            train_idcs = torch.arange(train_start, train_end, dtype=torch.long)
            semi_idcs = torch.arange(semi_start, semi_end, dtype=torch.long)

            class_idcs = torch.zeros(self.samples_per_class, dtype=torch.long)

            # First pass: every train index plus a random subset of the semi indices.
            collected = min(total_n, self.samples_per_class)
            class_idcs[:train_n] = train_idcs
            class_idcs[train_n:collected] = semi_idcs[torch.randperm(semi_n)[:collected - train_n]]

            if collected < self.samples_per_class:
                # Class is smaller than the epoch size: top up with repeated
                # random draws (without replacement per round) from all of its
                # indices. The index pool is loop-invariant, so build it once.
                all_idcs = torch.cat([train_idcs, semi_idcs])
                while collected < self.samples_per_class:
                    take = min(self.samples_per_class - collected, total_n)
                    class_idcs[collected:collected + take] = all_idcs[torch.randperm(total_n)[:take]]
                    collected += take

            per_class_idcs.append(class_idcs)

        idcs = torch.cat(per_class_idcs)[torch.randperm(self.length)]

        return iter(idcs)

    def __len__(self):
        return self.length

# class BalancedSampler(Sampler):
#     def __init__(self, svhn_top_k_partition):
#         super().__init__(None)
#         self.semi_per_class = svhn_top_k_partition.semi_per_class
#         self.train_per_class = svhn_top_k_partition.train_per_class
#         self.train_idx_ranges = svhn_top_k_partition.train_idx_ranges
#         self.semi_idx_ranges = svhn_top_k_partition.semi_idx_ranges
#
#         self.total_per_class = self.semi_per_class + self.train_per_class
#         self.samples_per_class_per_split = torch.max(self.total_per_class)
#         min_per_class = torch.min(self.total_per_class)
#         self.num_classes = len(self.semi_per_class)
#         self.length = self.num_classes * self.samples_per_class_per_split
#
#         print(f'Balanced Sampler: Max {self.samples_per_class_per_split} - Min {min_per_class} - Length {self.length}')
#
#     def __iter__(self):
#         intra_class_idcs = []
#         for i in range(self.num_classes):
#             i_all_idcs = torch.zeros(self.total_per_class[i], dtype=torch.long)
#             i_intra_idcs = torch.zeros(self.samples_per_class_per_split, dtype=torch.long)
#
#             i_train_start, i_train_end = self.train_idx_ranges[i]
#             i_all_idcs[:self.train_per_class[i]] = torch.arange(i_train_start, i_train_end, dtype=torch.long)
#
#             i_semi_start, i_semi_end = self.semi_idx_ranges[i]
#             i_all_idcs[self.train_per_class[i]:] = torch.arange(i_semi_start, i_semi_end, dtype=torch.long)
#
#             i_collected_samples = 0
#             while i_collected_samples < self.samples_per_class_per_split:
#                 samples_to_get = min(self.samples_per_class_per_split - i_collected_samples, self.total_per_class[i])
#                 next_samples = i_all_idcs[torch.randperm(self.total_per_class[i])[:samples_to_get]]
#
#                 i_intra_idcs[i_collected_samples:(i_collected_samples + samples_to_get)] = next_samples
#                 i_collected_samples = i_collected_samples + samples_to_get
#
#             intra_class_idcs.append(i_intra_idcs)
#
#         idcs = torch.cat(intra_class_idcs)[torch.randperm(self.length)]
#
#         return iter(idcs)
#
#     def __len__(self):
#         return self.length

class SVHNTinyImageBottomKPartition(Dataset):
    """Candidates that were not claimed by the top-k partition.

    The candidate pool is indexed like the rows of ``model_logits``: the first
    ``TINY_LENGTH`` rows belong to the tiny images file, the remaining rows to
    the SVHN extra split. Every row that is not listed in ``top_k_indices`` is
    part of this dataset.
    """

    def __init__(self, model_logits, top_k_indices, transform_base, temperature=1.0,
                 soft_labels=True, svhn_extra_val_split=True):
        """
        Args:
            model_logits: 2-D tensor of teacher logits, one row per candidate.
            top_k_indices: candidate indices to exclude from this dataset.
            transform_base: transform applied to every returned image.
            temperature: divisor applied to the logits of soft labels.
            soft_labels: return the teacher's softmax as target; otherwise a
                uniform distribution over the classes is returned.
            svhn_extra_val_split: read ``SVHNValidationExtraSplit`` instead of
                the plain torchvision SVHN ``extra`` split.
        """
        self.data_location = get_tiny_images_files(False)
        # The handle stays open: tiny images are read lazily in __getitem__.
        self.fileID = open(self.data_location, "rb")
        self.soft_labels = soft_labels
        self.temperature = temperature
        # Tiny images pass through ToPILImage before transform_base; the SVHN
        # dataset below receives transform_base directly.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        svhn_path = get_svhn_path()
        if svhn_extra_val_split:
            self.svhn_extra = SVHNValidationExtraSplit(svhn_path, split='extra-split', transform=transform_base)
        else:
            self.svhn_extra = datasets.SVHN(svhn_path, split='extra', transform=transform_base)

        self.model_logits = model_logits
        self.num_classes = model_logits.shape[1]
        self.is_tinyImage = torch.cat([torch.ones(TINY_LENGTH, dtype=torch.bool),
                                       torch.zeros(len(self.svhn_extra), dtype=torch.bool)])

        assert len(self.is_tinyImage) == model_logits.shape[0]

        # valid_indices holds every candidate index that is not in top_k_indices
        self.valid_indices = self._complement_indices(self.model_logits.shape[0], top_k_indices)

        self.length = len(self.valid_indices)

        print(f'Samples {self.length} - Temperature {self.temperature}')

    @staticmethod
    def _complement_indices(num_samples, excluded_indices):
        """Return the sorted 1-D tensor of indices in ``range(num_samples)`` not excluded."""
        keep = torch.ones(num_samples, dtype=torch.bool)
        keep[excluded_indices] = False
        # squeeze(1) rather than squeeze(): a single remaining index must stay
        # 1-D so that len() and indexing keep working.
        return torch.nonzero(keep, as_tuple=False).squeeze(1)

    def __getitem__(self, index):
        """Return ``(img, target distribution)`` for the ``index``-th remaining candidate."""
        # Plain int, matching how the top-k partition calls the same loaders.
        valid_index = self.valid_indices[index].item()

        if self.is_tinyImage[valid_index]:
            img = _load_tiny_image(valid_index, self.fileID)
            if self.transform is not None:
                img = self.transform(img)
        else:
            # SVHN extra rows follow the tiny image rows in the candidate pool
            img, _ = self.svhn_extra[valid_index - TINY_LENGTH]

        if self.soft_labels:
            model_prediction = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            model_prediction = (1. / self.num_classes) * torch.ones(self.num_classes)

        return img, model_prediction

    def __len__(self):
        return self.length
